#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;

// Forward-pass launcher. Only the declaration appears here; no definition is
// visible in this file, so it is expected to be supplied by another
// translation unit at link time.
//
// B, T, C, H : integer size parameters (presumably batch, sequence length,
//              channels and heads -- TODO confirm against the implementation).
// r, k, v, u : raw pointers to bf16 data.
// w          : raw pointer to float data (the only non-bf16 buffer).
// y          : raw pointer to bf16 data; presumably the output buffer --
//              verify against the implementation.
void cuda_forward(int B,
                  int T,
                  int C,
                  int H,
                  bf16* r,
                  bf16* k,
                  bf16* v,
                  float* w,
                  bf16* u,
                  bf16* y);
// Backward-pass launcher. Only the declaration appears here; no definition is
// visible in this file, so it is expected to be supplied by another
// translation unit at link time.
//
// B, T, C, H         : integer size parameters (presumably batch, sequence
//                      length, channels and heads -- TODO confirm).
// r, k, v, u         : raw pointers to bf16 data.
// w                  : raw pointer to float data.
// gy                 : raw pointer to bf16 data; presumably the incoming
//                      gradient of the forward output -- TODO confirm.
// gr, gk, gv, gw, gu : raw pointers to bf16 data; presumably the gradient
//                      buffers written for r, k, v, w and u -- TODO confirm.
//                      Note that gw is bf16 even though w itself is float.
void cuda_backward(int B,
                   int T,
                   int C,
                   int H,
                   bf16* r,
                   bf16* k,
                   bf16* v,
                   float* w,
                   bf16* u,
                   bf16* gy,
                   bf16* gr,
                   bf16* gk,
                   bf16* gv,
                   bf16* gw,
                   bf16* gu);

void forward(int64_t B,
             int64_t T,
             int64_t C,
             int64_t H,
             torch::Tensor& r,
             torch::Tensor& k,
             torch::Tensor& v,
             torch::Tensor& w,
             torch::Tensor& u,
             torch::Tensor& y)
{
    cuda_forward(B,
                 T,
                 C,
                 H,
                 r.data_ptr<bf16>(),
                 k.data_ptr<bf16>(),
                 v.data_ptr<bf16>(),
                 w.data_ptr<float>(),
                 u.data_ptr<bf16>(),
                 y.data_ptr<bf16>());
}
void backward(int64_t B,
              int64_t T,
              int64_t C,
              int64_t H,
              torch::Tensor& r,
              torch::Tensor& k,
              torch::Tensor& v,
              torch::Tensor& w,
              torch::Tensor& u,
              torch::Tensor& gy,
              torch::Tensor& gr,
              torch::Tensor& gk,
              torch::Tensor& gv,
              torch::Tensor& gw,
              torch::Tensor& gu)
{
    cuda_backward(B,
                  T,
                  C,
                  H,
                  r.data_ptr<bf16>(),
                  k.data_ptr<bf16>(),
                  v.data_ptr<bf16>(),
                  w.data_ptr<float>(),
                  u.data_ptr<bf16>(),
                  gy.data_ptr<bf16>(),
                  gr.data_ptr<bf16>(),
                  gk.data_ptr<bf16>(),
                  gv.data_ptr<bf16>(),
                  gw.data_ptr<bf16>(),
                  gu.data_ptr<bf16>());
}
// pybind11 bindings: expose forward() and backward() as functions of the
// Python extension module whose name is given by TORCH_EXTENSION_NAME.
// The third argument of each def() is the Python-visible docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &forward, "wkv6 forward");
    m.def("backward", &backward, "wkv6 backward");
}

// Dispatcher registration: register the same forward() and backward()
// functions as operators in the `wkv6` library namespace. No schema string is
// given, so the operator schemas are inferred from the C++ signatures.
// This is in addition to the pybind11 bindings defined in this file.
TORCH_LIBRARY(wkv6, m)
{
    m.def("forward", forward);
    m.def("backward", backward);
}
